{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {},
   "outputs": [],
   "source": [
    "from handcalcs.decorator import handcalc\n",
    "import numpy as np"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {},
   "outputs": [],
   "source": [
    "@handcalc\n",
    "def alpha_beta_search(hash_table, start_time, time_out, memory_out, chessboard, current_color, remaining_depth=6,\n",
    "                      rounds=4,\n",
    "                      alphas=np.array([-np.inf, -np.inf])):\n",
    "    \"\"\"\n",
    "\n",
    "    :param rounds:\n",
    "    :param hash_table: Dict.empty(key_type=types.unicode_type, value_type=types.float64[:])\n",
    "    :param start_time:\n",
    "    :param time_out:\n",
    "    :param memory_out:\n",
    "\n",
    "    :param chessboard:\n",
    "    :param current_color:\n",
    "    :param remaining_depth: 0 表示 直接对节点估值，不合法。 1表示一层贪心。 根据时间资源和回合数，请合理分配搜索深度。目前知道10层全回合OK的\n",
    "    :param alphas: 0:到目前为止，路径上发现的 color=-1这个agent 的最佳选择值\n",
    "                  1:到目前为止，路径上发现的 color= 1这个agent 的最佳选择值\n",
    "    :return: 返回对于chessboard,color这个节点，它最大的选择值是多少，以及它选择了哪个子节点。\n",
    "    \"\"\"\n",
    "    # if time.time() - start_time >= time_out:\n",
    "    #     raise TimeoutError(\"Too deep, 速回！\")\n",
    "    alphas = alphas.copy()  # 防止修改上面的alphas\n",
    "    if is_terminal(chessboard):\n",
    "        utility = current_color * get_winner(chessboard)  # winner的颜色和我相等，就是1（颜色的平方性质）， 和我的颜色不等，就是-1.\n",
    "        return symmetry_normalized_value(-1, 1, utility), None  # 满足截断性。由于其他价值函数也归一化了，-1和1就是最小值和最大值。 满足对手对称性。\n",
    "\n",
    "    acts = typed.List(actions(chessboard, current_color))\n",
    "    if len(acts) == 0:\n",
    "        # 只能选择跳过这个action，value为对方的value\n",
    "        value, move = alpha_beta_search(hash_table, start_time, time_out, memory_out, chessboard, -current_color,\n",
    "                                        remaining_depth - 1, rounds + 1, alphas)\n",
    "        value = -value\n",
    "        if len(hash_table) < memory_out:\n",
    "            hash_table[hash_board(chessboard)] = value * current_color\n",
    "        return value, None  # 对手的值是和我反的。 我方没有action可以做。\n",
    "    new_chessboards = typed.List(\n",
    "        [updated_chessboard(chessboard, current_color, a) for a in acts])  # 用最多10倍内存换一半时间（排序和实际操作共用结果）\n",
    "    insertion_sort(acts, new_chessboards, current_color, hash_table, rounds)\n",
    "\n",
    "    if remaining_depth <= 1:  # 比如要求搜索1层，就是直接对max节点的所有邻接节点排序返回最大的。\n",
    "        v = 组合策略1(new_chessboards[0], current_color, rounds)\n",
    "        if v == 1:\n",
    "            print(\"examine this!\")\n",
    "        return v, acts[0]  # 评价永远是根据我方的棋盘\n",
    "\n",
    "    value, move = -np.inf, None  # 写在一起。每个节点都尝试让自己的价值最大化\n",
    "    this_color_idx, other_color_idx = int((current_color + 1) // 2), int((-current_color + 1) // 2)\n",
    "    for i, new_chessboard in enumerate(new_chessboards):\n",
    "        action = acts[i]\n",
    "\n",
    "        new_value, t = alpha_beta_search(hash_table, start_time, time_out, memory_out, new_chessboard, -current_color,\n",
    "                                         remaining_depth - 1, rounds + 1, alphas)\n",
    "        new_value = -new_value\n",
    "        if len(hash_table) < memory_out:\n",
    "            hash_table[hash_board(new_chessboard)] = new_value * current_color\n",
    "\n",
    "        if new_value > value:\n",
    "            value, move = new_value, action\n",
    "            alphas[this_color_idx] = max(alphas[this_color_idx], value)\n",
    "        # 另一种颜色的某一个节点已经到达了c = -beta的水平，低于c的都不接受。\n",
    "        # 而我这个节点，至少可以达到v的水平。\n",
    "        # 在那个对手节点看来，我至多会选择-v， 如果它自己的c已经比我这个-v大了，\n",
    "        # 他就不会考虑我，我被剪枝，随便返回一个我的值和选择。\n",
    "        if -value <= alphas[other_color_idx]:\n",
    "            return value, move\n",
    "    return value, move"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3.10.5 64-bit",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.5"
  },
  "orig_nbformat": 4,
  "vscode": {
   "interpreter": {
    "hash": "d4ad195ff334c471a543b0a7bb226f1a689063219ec9cc66cee2dec60707a1ec"
   }
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
